[Feat] Simplify and unify model loading - from_pretrained #1915
Conversation
@@ -76,8 +76,6 @@ def __init__(
         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
         self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())

-        self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
unused layers
@@ -76,8 +76,6 @@ def __init__(
             d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
         )

-        self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
unused layers
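For context on why removing these layers matters later in this PR, here is a hedged illustration (not doctr code): once the layers are gone, strictly loading an older checkpoint that still contains their weights fails, which is what the parseq backward-compatibility handling discussed below addresses.

```python
# Hedged illustration (not doctr code): removing layers from a module changes its
# state_dict keys, so strictly loading an "old" checkpoint that still contains the
# removed norm weights raises a RuntimeError.
import torch.nn as nn

old = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))  # checkpoint was saved with the norm
new = nn.Sequential(nn.Linear(4, 4))                    # norm removed as "unused"

try:
    new.load_state_dict(old.state_dict())               # strict=True by default
except RuntimeError as err:
    print("unexpected keys:", err)
```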
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1915      +/-   ##
==========================================
- Coverage   96.81%   96.80%   -0.01%
==========================================
  Files         172      172
  Lines        8444     8515      +71
==========================================
+ Hits         8175     8243      +68
- Misses        269      272       +3

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
@SiddhantBahuguna maybe ...
MAGC doesn't need its own from_pretrained method because it's only a layer plugged into the ResNet implementation.
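To illustrate the point above, a hedged, minimal sketch (not the actual doctr code): the MAGC weights live inside the backbone's state_dict, so they are restored when the backbone itself is loaded and no per-layer from_pretrained is needed.

```python
# Hedged illustration (not the actual doctr implementation): MAGC is just a layer
# wired into the ResNet backbone, so its weights travel with the backbone checkpoint.
import torch
import torch.nn as nn


class MAGCSketch(nn.Module):  # simplified stand-in for the real MAGC layer
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.transform = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.transform(x)


backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), MAGCSketch(8))
# the MAGC weights ("1.transform.weight", ...) are part of the backbone's state_dict
print(list(backbone.state_dict().keys()))
```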
@@ -219,10 +219,10 @@ jobs:
           pip install -e .[torch,viz,html] --upgrade
       - if: matrix.framework == 'tensorflow'
         name: Evaluate text recognition (TF)
-        run: python references/recognition/evaluate_tensorflow.py crnn_mobilenet_v3_small --dataset IIIT5K -b 32
+        run: python references/recognition/evaluate_tensorflow.py crnn_mobilenet_v3_small --dataset SVT -b 32
The IIIT5K download URL is down again. I already wrote a mail about it; use another dataset for the CI test.
@felixdittrich92 Edit: misread the implementation. But it does feel confusing that ...
@sebastianMindee So the idea of the wrapper method was that we can use this place for custom model logic. Btw, we wouldn't miss implementing this method, because if we add the corresponding test entries the CI would fail (I added a unittest to check that all models have the from_pretrained method).
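As a rough sketch of the kind of CI guard described above (the exact test in the PR may differ; the listed architecture names are just examples):

```python
# Hedged sketch: a parametrized test asserting that exported recognition
# architectures expose a callable from_pretrained method.
import pytest

from doctr.models import recognition


@pytest.mark.parametrize("arch_name", ["crnn_vgg16_bn", "vitstr_small", "parseq"])
def test_model_exposes_from_pretrained(arch_name):
    model = recognition.__dict__[arch_name](pretrained=False)
    assert callable(getattr(model, "from_pretrained", None))
```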
After double checking, I have to revert the above case because it would result in the same issue as modifying ... And I'm not open to switching to full lazy loading by skipping anything non-matching 😅

import logging
from typing import Any

import torch.nn as nn
import validators

from doctr.utils.data import download_from_url
from ..recognition import PARSeq  # circular import
def load_pretrained_params(
    model: nn.Module,
    path_or_url: str | None = None,
    hash_prefix: str | None = None,
    ignore_keys: list[str] | None = None,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")

    Args:
        model: the PyTorch model to be loaded
        path_or_url: the path or URL to the model parameters (checkpoint)
        hash_prefix: first characters of SHA256 expected hash
        ignore_keys: list of weights to be ignored from the state_dict
        **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
    """
    if path_or_url is None:
        logging.warning("No model URL or Path provided, using default initialization.")
        return

    archive_path = (
        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
        if validators.url(path_or_url)
        else path_or_url
    )

    if isinstance(model, PARSeq):
        # NOTE: This is required to make the model backward compatible with already
        # trained models of docTR version <0.11.1
        # ref.: https://github.com/mindee/doctr/issues/1911
        if ignore_keys is None:
            ignore_keys = []
        ignore_keys.extend([
            "decoder.attention_norm.weight",
            "decoder.attention_norm.bias",
            "decoder.cross_attention_norm.weight",
            "decoder.cross_attention_norm.bias",
        ])
    ...

An option would be to check only for the name 🙈
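For reference, the name-based check mentioned above could look roughly like this (a hedged sketch; _parseq_ignore_keys is a hypothetical helper that would replace the isinstance(model, PARSeq) branch in load_pretrained_params and avoid the circular import):

```python
# Hedged sketch of the "check only for the name" idea: compare the class name
# instead of importing PARSeq, which sidesteps the circular import.
def _parseq_ignore_keys(model, ignore_keys):
    if type(model).__name__ == "PARSeq":
        # extend the ignored weights for backward compatibility (ref.: issue #1911)
        ignore_keys = list(ignore_keys or []) + [
            "decoder.attention_norm.weight",
            "decoder.attention_norm.bias",
            "decoder.cross_attention_norm.weight",
            "decoder.cross_attention_norm.bias",
        ]
    return ignore_keys
```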
Ok, finally:

Option 1: we keep it as provided in the PR:

model = vitstr()
model.from_pretrained("path/to/weights.X")

Advantage:
Disadvantage:

Option 2: we handle everything on the user side; then it would be:

model = vitstr()
load_pretrained_params(model, "path/to/weights.X")

Advantage:
Disadvantage:

@sebastianMindee @SiddhantBahuguna Now we should decide which option you would prefer 😄 My vote goes to option 1, because from the user's view it feels a bit smarter 😅 AND in this case we have well unittested that each model provides a defined method to load weights - this case provides several checks -
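For illustration, a minimal sketch of what an Option 1 style wrapper could look like (the class name, import path, and mixin-free structure here are assumptions, not the PR's actual code); model-specific logic such as parseq's ignore_keys could live inside the override:

```python
# Hedged sketch of Option 1 (illustrative only): each model exposes a thin
# from_pretrained wrapper that delegates to the shared loader.
from typing import Any

import torch.nn as nn

from doctr.models.utils import load_pretrained_params  # assumed import path


class VitSTRSketch(nn.Module):  # hypothetical stand-in for a recognition model
    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(384, 128)

    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
        # model-specific handling (e.g. ignore_keys for backward compatibility)
        # could be injected here before delegating to the shared loader
        load_pretrained_params(self, path_or_url, **kwargs)
```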
LGTM but I'd like to wait for @SiddhantBahuguna's approval before giving it the official go 🙂
Yeah that's fine
Will update the missing part for #1912 next week :)
Force-pushed from f1867c6 to c3fb81e
Done :)
Force-pushed from c3fb81e to 0c5b02e
Soft breaking change:
- from_pretrained should be used instead of manual loading
- parseq: it would fail to load "old" weights after the next release when loading manually instead of via from_pretrained

This PR:
- keeps parseq backward compatible for already trained models (ref.: Unused parameters in parseq. #1911)
- design: each model provides a from_pretrained method -> unittest

Any feedback is welcome 🤗
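As a quick usage illustration of the soft breaking change (hedged: the checkpoint path is a placeholder, and the imports follow the docstring shown earlier in this thread):

```python
# Hedged usage sketch of the change described above (path is a placeholder).
from doctr.models import load_pretrained_params, parseq

model = parseq(pretrained=False)

# before: manual loading via the shared helper
load_pretrained_params(model, "path/to/weights.pt")

# after this PR: the unified per-model entry point
model.from_pretrained("path/to/weights.pt")
```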