
[Feat] Simplify and unify model loading - from_pretrained #1915


Open
wants to merge 19 commits into main

Conversation

felixdittrich92
Contributor

@felixdittrich92 felixdittrich92 commented Mar 28, 2025

soft breaking change

  • For the next release notes: from_pretrained should be used instead of manual loading
  • Especially with PARSeq, loading "old" weights manually instead of via from_pretrained would fail after the next release

This PR:

  • Unify model loading between TF & PT (more control on our end)
  • Simplify custom model loading & update the docs
  • The following allows users to load from a local file or, now also, from a URL:
# NEW
model = vitstr(pretrained=False)
model.from_pretrained(...)  # local path or url to .pt or .h5

# Instead of depending on the backend:
# PyTorch
reco_params = torch.load('<path_to_pt>', map_location="cpu")
reco_model.load_state_dict(reco_params)
# Or TensorFlow
reco_model.load_weights(..)
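The unified entry point has to distinguish a local checkpoint path from a URL before loading. A minimal sketch of that dispatch (a hypothetical helper, not docTR's actual code — the real implementation uses the validators package and download_from_url with caching):

```python
from urllib.parse import urlparse


def resolve_checkpoint(path_or_url: str) -> tuple[str, bool]:
    """Return the checkpoint source and whether it must be downloaded first.

    Illustrative only: docTR's real loader additionally caches downloads
    under a models subdirectory and can verify a SHA256 hash prefix.
    """
    is_url = urlparse(path_or_url).scheme in ("http", "https")
    return path_or_url, is_url
```

With this split, the same `from_pretrained(...)` call works for both `"/local/weights.pt"` and `"https://host/weights.pt"`, which is the behavior the PR description promises.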

design:

  • Normally I would prefer a base class providing this functionality, which each model inherits from - but this would require a larger refactoring, especially for the torchvision / Keras imported backbone (classification) models
  • From now on, every model needs to have the from_pretrained method -> enforced by a unittest

Any feedback is welcome 🤗
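The enforcement mentioned above can be sketched as follows (function and class names here are illustrative stand-ins, not the actual test code in the PR):

```python
def assert_has_from_pretrained(model: object) -> None:
    """Core check of the described unittest: every model instance
    must expose a callable from_pretrained method."""
    if not callable(getattr(model, "from_pretrained", None)):
        raise AssertionError(
            f"{type(model).__name__} does not implement from_pretrained"
        )


class GoodModel:
    def from_pretrained(self, path_or_url: str) -> None:
        pass  # stand-in for the real loading logic


class BadModel:
    pass  # forgot to implement from_pretrained -> the check fails
```

Run over every entry in the model zoo, such a check makes the CI fail as soon as a model is added without the method.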

@felixdittrich92 felixdittrich92 self-assigned this Mar 28, 2025
@felixdittrich92 felixdittrich92 added topic: documentation Improvements or additions to documentation type: enhancement Improvement topic: build Related to dependencies and build topic: ci Related to CI module: models Related to doctr.models ext: tests Related to tests folder type: breaking change Introducing a breaking change framework: pytorch Related to PyTorch backend framework: tensorflow Related to TensorFlow backend ext: docs Related to docs folder labels Mar 28, 2025
@@ -76,8 +76,6 @@ def __init__(
self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())

self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
Contributor Author

unused layers

@@ -76,8 +76,6 @@ def __init__(
d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
)

self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
Contributor Author

unused layers


codecov bot commented Mar 28, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.80%. Comparing base (c4dd472) to head (0c5b02e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1915      +/-   ##
==========================================
- Coverage   96.81%   96.80%   -0.01%     
==========================================
  Files         172      172              
  Lines        8444     8515      +71     
==========================================
+ Hits         8175     8243      +68     
- Misses        269      272       +3     
Flag Coverage Δ
unittests 96.80% <100.00%> (-0.01%) ⬇️


@felixdittrich92
Contributor Author

@SiddhantBahuguna maybe from_weights would be a better method name, or something different? ^^

Contributor Author

MAGC doesn't need its own from_pretrained method because it's only a layer plugged into the ResNet implementation

@@ -219,10 +219,10 @@ jobs:
pip install -e .[torch,viz,html] --upgrade
- if: matrix.framework == 'tensorflow'
name: Evaluate text recognition (TF)
-        run: python references/recognition/evaluate_tensorflow.py crnn_mobilenet_v3_small --dataset IIIT5K -b 32
+        run: python references/recognition/evaluate_tensorflow.py crnn_mobilenet_v3_small --dataset SVT -b 32
Contributor Author

The IIIT5K download URL is down .. again .. I already wrote a mail. Use another dataset for the CI test for now.

@felixdittrich92 felixdittrich92 marked this pull request as ready for review March 31, 2025 08:39
@sebastianMindee
Collaborator

sebastianMindee commented Mar 31, 2025

@felixdittrich92 Since most of the implementations are exactly the same, wouldn't it be better to just have a common definition and import it where needed (of course keeping the different ones in their respective classes)? I feel like having this many duplicated bits is a sure way to forget about it when it matters down the road 😅

Edit: misread the implementation. But it does feel confusing that the from_pretrained implementations are all made into methods just to relay params to load_pretrained_params? Why not just add the non-standard implementations to load_pretrained_params instead?

@felixdittrich92
Contributor Author

felixdittrich92 commented Mar 31, 2025

@felixdittrich92 Since most of the implementations are exactly the same, wouldn't it be better to just have a common definition and import it where needed (of course keeping the different ones in their respective classes)? I feel like having this many duplicated bits is a sure way to forget about it when it matters down the road 😅

Edit: misread the implementation. But it does feel confusing that the from_pretrained implementations are all made into methods just to relay params to load_pretrained_params? Why not just add the non-standard implementations to load_pretrained_params instead?

@sebastianMindee
Yeah, in general the cleanest way would be a base class which already implements the functionality, with all models inheriting from it - but this would require a larger refactoring, as mentioned, and I think it would raise a circular import issue.

So the idea of the wrapper method was that we can use this place for custom model logic (like PARSeq in this case).
But I agree we could also modify load_pretrained_params slightly, make it private, import it from models.utils and bind it afterwards: model.load_pretrained_params = types.MethodType(_load_pretrained_params, model). That would remove most of the boilerplate code. I tried to avoid this as much as possible because it feels a bit hacky anyway, so I used it only at places where a larger refactoring would otherwise be required 😅
wdyt ?
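The binding idea can be sketched like this (VitSTR and the loader body are stand-ins for illustration; the real shared loader also resolves URLs, verifies hashes and handles ignore_keys):

```python
import types


def _load_pretrained_params(self, path_or_url=None, **kwargs):
    # Stand-in for the private shared loader living in models.utils.
    self.loaded_from = path_or_url


class VitSTR:
    # Minimal stand-in for a docTR model class.
    pass


model = VitSTR()
# Bind the shared function to this instance, as discussed above:
model.load_pretrained_params = types.MethodType(_load_pretrained_params, model)
model.load_pretrained_params("path/to/weights.pt")
```

types.MethodType turns the free function into a bound method of that one instance, so self resolves to the model without touching the class hierarchy.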

btw, we wouldn't forget to implement this method, because with the corresponding test entries the CI would fail (I added a unittest checking that all models have the from_pretrained method)

@felixdittrich92
Contributor Author

felixdittrich92 commented Apr 1, 2025

@sebastianMindee

After double checking, I have to revert the above approach because it results in the same issue as modifying load_pretrained_params directly - a circular import:

And I'm not open to switching to fully lazy loading by skipping anything non-matching 😅

import logging
from typing import Any

import validators
from torch import nn

from doctr.utils.data import download_from_url
from ..recognition import PARSeq  # circular

def load_pretrained_params(
    model: nn.Module,
    path_or_url: str | None = None,
    hash_prefix: str | None = None,
    ignore_keys: list[str] | None = None,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")

    Args:
        model: the PyTorch model to be loaded
        path_or_url: the path or URL to the model parameters (checkpoint)
        hash_prefix: first characters of SHA256 expected hash
        ignore_keys: list of weights to be ignored from the state_dict
        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
    """
    if path_or_url is None:
        logging.warning("No model URL or Path provided, using default initialization.")
        return

    archive_path = (
        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
        if validators.url(path_or_url)
        else path_or_url
    )

    if isinstance(model, PARSeq):
        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
        # ref.: https://github.com/mindee/doctr/issues/1911
        if ignore_keys is None:
            ignore_keys = []
        ignore_keys.extend([
            "decoder.attention_norm.weight",
            "decoder.attention_norm.bias",
            "decoder.cross_attention_norm.weight",
            "decoder.cross_attention_norm.bias",
        ])

...

An option would be to check only for the class name 🙈

if type(model).__name__ == "PARSeq":
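Expanded as a runnable sketch (the classes here are stand-ins for the real docTR ones), the name check keeps models.utils free of the circular PARSeq import, at the cost of matching on the class name only:

```python
def needs_parseq_compat(model: object) -> bool:
    # String comparison instead of isinstance(model, PARSeq):
    # no ``from ..recognition import PARSeq`` needed, so no circular import.
    return type(model).__name__ == "PARSeq"


class PARSeq:  # stand-in for doctr.models.recognition.PARSeq
    pass


class VitSTR:  # stand-in for any other model
    pass
```

The trade-off: any unrelated class that happens to be named "PARSeq" would also match, and subclasses with different names would not.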

@felixdittrich92
Copy link
Contributor Author

felixdittrich92 commented Apr 1, 2025

Ok finally:

Option 1:

model = vitstr()
model.from_pretrained("path/to/weights.X")

We keep it as provided in the PR

Advantage:

  • Clear structure for each model & class-bound logic (no condition on model type / name required)
  • Interface already familiar from a lot of other libs like transformers / sentence-transformers and so on -> class method .from_pretrained

Disadvantage:

  • Produces some boilerplate code & a wrapper method
  • Uses method binding at places where a ("to class wrapper") refactoring would otherwise be required

Option 2:

We handle everything in load_pretrained_params.

On the user side it would then be:

model = vitstr()
load_pretrained_params(model, "path/to/weights.X")

Advantage:

  • Avoids all the boilerplate
  • Minimal changes required

Disadvantage:

  • Model-name condition baked into load_pretrained_params
  • Feels less familiar compared to well-known libs
  • Exposes additional options (hash_prefix, ignore_keys, ..) which could be a bit confusing for a simple loading function; in option 1 these are "masked" as kwargs

@sebastianMindee @SiddhantBahuguna Now we should decide: which option do you prefer? 😄

My vote goes to option 1 because from the user's view it feels a bit smarter 😅 AND in this case we have well-unittested that each model provides a defined method to load weights. This gives us several checks: pretrained=True would fail if the model does not implement the method (because of the inner model.from_pretrained call when pretrained=True is passed), and the added assert hasattr checks catch it as well.
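The pretrained=True failure path can be illustrated with a stand-in class (not docTR's actual code) whose constructor mirrors the inner call described above:

```python
class BrokenModel:
    """Stand-in for a model that forgot to implement from_pretrained."""

    def __init__(self, pretrained: bool = False):
        if pretrained:
            # Mirrors the inner model.from_pretrained call: since the
            # method is missing, this raises AttributeError immediately.
            self.from_pretrained("https://example.com/default.pt")
```

So even without the dedicated unittest, constructing such a model with pretrained=True fails at once rather than silently shipping unloadable weights.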

Collaborator

@sebastianMindee sebastianMindee left a comment

LGTM but I'd like to wait for @SiddhantBahuguna's approval before giving it the official go 🙂

@felixdittrich92
Contributor Author

LGTM but I'd like to wait for @SiddhantBahuguna's approval before giving it the official go 🙂

Yeah that's fine ☺️

@felixdittrich92 felixdittrich92 added this to the 0.12.0 milestone Apr 11, 2025
@felixdittrich92
Contributor Author

Will update the missing part for #1912 next week :)

@felixdittrich92 felixdittrich92 linked an issue Apr 12, 2025 that may be closed by this pull request
@felixdittrich92
Contributor Author

Done :)

Successfully merging this pull request may close these issues.

Unused parameters in parseq.