Skip to content

Commit 8cc0064

Browse files
authored
Merge branch 'ollama:main' into main
2 parents 88936d5 + 98d44fa commit 8cc0064

File tree

193 files changed

+19327
-11302
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

193 files changed

+19327
-11302
lines changed

.github/workflows/test.yaml

+15-16
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ jobs:
7878
include:
7979
- preset: CPU
8080
- preset: CUDA
81-
install: https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe
82-
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
81+
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
82+
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
8383
- preset: ROCm
8484
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
8585
flags: '-DAMDGPU_TARGETS=gfx1010'
@@ -102,7 +102,7 @@ jobs:
102102
$ErrorActionPreference = "Stop"
103103
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
104104
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
105-
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.8", "nvcc_11.8", "cublas_11.8", "cublas_dev_11.8")) -NoNewWindow -Wait
105+
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
106106
}
107107
108108
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
@@ -190,28 +190,27 @@ jobs:
190190

191191
go-version-file: go.mod
192192

193-
# TODO(bmizerany): replace this heavy tool with just the
194-
# tools/checks/binaries we want and then make them all run in parallel
195-
# across jobs, not on a single tiny vm on Github Actions.
196-
- uses: golangci/golangci-lint-action@v6
197-
with:
198-
args: --timeout 10m0s -v
199-
200-
- name: go test
201-
# Do not skip tests in the face of linter errors, or 'go mod tidy'
202-
# checks, which are secondary to the tests. Tests trump linters.
203-
if: always()
204-
run: go test -count=1 -benchtime=1x ./...
205-
206193
# It is tempting to run this in a platform independent way, but the past
207194
# shows this codebase will see introductions of platform specific code
208195
# generation, and so we need to check this per platform to ensure we
209196
# don't abuse go generate on specific platforms.
210197
- name: check that 'go generate' is clean
198+
if: always()
211199
run: |
212200
go generate ./...
213201
git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
214202
203+
- name: go test
204+
if: always()
205+
run: go test -count=1 -benchtime=1x ./...
206+
207+
# TODO(bmizerany): replace this heavy tool with just the
208+
# tools/checks/binaries we want and then make them all run in parallel
209+
# across jobs, not on a single tiny vm on Github Actions.
210+
- uses: golangci/golangci-lint-action@v6
211+
with:
212+
args: --timeout 10m0s -v
213+
215214
- name: cache save
216215
# Always save the cache, even if the job fails. The artifacts produced
217216
# during the building of test binaries are not all for naught. They can

CMakePresets.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
"name": "CUDA 11",
2222
"inherits": [ "CUDA" ],
2323
"cacheVariables": {
24-
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86"
24+
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86"
2525
}
2626
},
2727
{
2828
"name": "CUDA 12",
2929
"inherits": [ "CUDA" ],
3030
"cacheVariables": {
31-
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86;87;89;90;90a"
31+
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;100"
3232
}
3333
},
3434
{

CONTRIBUTING.md

+57-6
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
66

77
See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
88

9-
## Pull requests
10-
119
### Ideal issues
1210

1311
* [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
@@ -26,11 +24,64 @@ See the [development documentation](./docs/development.md) for instructions on h
2624
* Changes that add significant friction to the user experience
2725
* Changes that create a large future maintenance burden for maintainers and contributors
2826

29-
### Best practices
27+
## Proposing a (non-trivial) change
28+
29+
> By "non-trivial", we mean a change that is not a bug fix or small
30+
> documentation update. If you are unsure, please ask us on our [Discord
31+
> server](https://discord.gg/ollama).
32+
33+
Before opening a non-trivial Pull Request, please open an issue to discuss the change and
34+
get feedback from the maintainers. This helps us understand the context of the
35+
change and how it fits into Ollama's roadmap and prevents us from duplicating
36+
work or you from spending time on a change that we may not be able to accept.
37+
38+
Tips for proposals:
39+
40+
* Explain the problem you are trying to solve, not what you are trying to do.
41+
* Explain why the change is important.
42+
* Explain how the change will be used.
43+
* Explain how the change will be tested.
44+
45+
Additionally, for bonus points: Provide draft documentation you would expect to
46+
see if the change were accepted.
47+
48+
## Pull requests
49+
50+
**Commit messages**
51+
52+
The title should look like:
53+
54+
<package>: <short description>
55+
56+
The package is the most affected Go package. If the change does not affect Go
57+
code, then use the directory name instead. Changes to a single well-known
58+
file in the root directory may use the file name.
59+
60+
The short description should start with a lowercase letter and be a
61+
continuation of the sentence:
62+
63+
"This changes Ollama to..."
64+
65+
Examples:
66+
67+
llm/backend/mlx: support the llama architecture
68+
CONTRIBUTING: provide clairity on good commit messages, and bad
69+
70+
Bad Examples:
71+
72+
feat: add more emoji
73+
fix: was not using famous web framework
74+
chore: generify code
75+
76+
**Tests**
77+
78+
Please include tests. Strive to test behavior, not implementation.
79+
80+
**New dependencies**
3081

31-
* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
32-
* Tests: please add test coverage to changes where possible.
33-
* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
82+
Dependencies should be added sparingly. If you are adding a new dependency,
83+
please explain why it is necessary and what other ways you attempted that
84+
did not work without it.
3485

3586
## Need help?
3687

Makefile.sync

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
UPSTREAM=https://github.com/ggerganov/llama.cpp.git
22
WORKDIR=llama/vendor
3-
FETCH_HEAD=46e3556e01b824e52395fb050b29804b6cff2a7c
3+
FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c
44

55
.PHONY: help
66
help:

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
524524
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
525525
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
526526
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
527+
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
527528

528529
### Mobile
529530

api/client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// repository].
1111
//
1212
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
13-
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
13+
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
1414
package api
1515

1616
import (

docs/development.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ go run . serve
6969
7070
## Windows (ARM)
7171

72-
Windows ARM does not support additional acceleration libraries at this time.
72+
Windows ARM does not support additional acceleration libraries at this time. Do not use cmake, simply `go run` or `go build`.
7373

7474
## Linux
7575

envconfig/config.go

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func AllowedOrigins() (origins []string) {
7373
"file://*",
7474
"tauri://*",
7575
"vscode-webview://*",
76+
"vscode-file://*",
7677
)
7778

7879
return origins

envconfig/config_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ func TestOrigins(t *testing.T) {
6969
"file://*",
7070
"tauri://*",
7171
"vscode-webview://*",
72+
"vscode-file://*",
7273
}},
7374
{"http://10.0.0.1", []string{
7475
"http://10.0.0.1",
@@ -88,6 +89,7 @@ func TestOrigins(t *testing.T) {
8889
"file://*",
8990
"tauri://*",
9091
"vscode-webview://*",
92+
"vscode-file://*",
9193
}},
9294
{"http://172.16.0.1,https://192.168.0.1", []string{
9395
"http://172.16.0.1",
@@ -108,6 +110,7 @@ func TestOrigins(t *testing.T) {
108110
"file://*",
109111
"tauri://*",
110112
"vscode-webview://*",
113+
"vscode-file://*",
111114
}},
112115
{"http://totally.safe,http://definitely.legit", []string{
113116
"http://totally.safe",
@@ -128,6 +131,7 @@ func TestOrigins(t *testing.T) {
128131
"file://*",
129132
"tauri://*",
130133
"vscode-webview://*",
134+
"vscode-file://*",
131135
}},
132136
}
133137
for _, tt := range cases {

fs/ggml/ggml.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
100100
return keyValue(kv, key, append(defaultValue, 0)...)
101101
}
102102

103+
func (kv KV) Bool(key string, defaultValue ...bool) bool {
104+
return keyValue(kv, key, append(defaultValue, false)...)
105+
}
106+
103107
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
104108
r := keyValue(kv, key, &array{})
105109
s := make([]string, r.size)
@@ -120,7 +124,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
120124
return s
121125
}
122126

123-
func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
127+
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
124128
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
125129
key = kv.Architecture() + "." + key
126130
}

go.mod

+8-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/ollama/ollama
22

3-
go 1.24
3+
go 1.24.0
44

55
require (
66
github.com/containerd/console v1.0.3
@@ -11,7 +11,7 @@ require (
1111
github.com/spf13/cobra v1.7.0
1212
github.com/stretchr/testify v1.9.0
1313
github.com/x448/float16 v0.8.4
14-
golang.org/x/sync v0.10.0
14+
golang.org/x/sync v0.11.0
1515
)
1616

1717
require (
@@ -69,12 +69,12 @@ require (
6969
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
7070
github.com/ugorji/go/codec v1.2.12 // indirect
7171
golang.org/x/arch v0.8.0 // indirect
72-
golang.org/x/crypto v0.31.0
73-
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
74-
golang.org/x/net v0.25.0 // indirect
75-
golang.org/x/sys v0.28.0
76-
golang.org/x/term v0.27.0
77-
golang.org/x/text v0.21.0
72+
golang.org/x/crypto v0.33.0
73+
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
74+
golang.org/x/net v0.35.0 // indirect
75+
golang.org/x/sys v0.30.0
76+
golang.org/x/term v0.29.0
77+
golang.org/x/text v0.22.0
7878
google.golang.org/protobuf v1.34.1
7979
gopkg.in/yaml.v3 v3.0.1 // indirect
8080
)

go.sum

+14-14
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
214214
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
215215
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
216216
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
217-
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
218-
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
217+
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
218+
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
219219
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
220220
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
221221
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
222222
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
223223
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
224224
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
225-
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
226-
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
225+
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
226+
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
227227
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
228228
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
229229
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
257257
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
258258
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
259259
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
260-
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
261-
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
260+
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
261+
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
262262
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
263263
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
264264
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
268268
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
269269
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
270270
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
271-
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
272-
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
271+
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
272+
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
273273
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
274274
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
275275
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
285285
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
286286
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
287287
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
288-
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
289-
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
288+
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
289+
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
290290
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
291-
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
292-
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
291+
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
292+
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
293293
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
294294
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
295295
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
296296
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
297-
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
298-
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
297+
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
298+
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
299299
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
300300
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
301301
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

kvcache/causal.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,10 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
330330
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
331331
}
332332

333-
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
334-
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
333+
ctx.Forward(
334+
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
335+
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
336+
)
335337
}
336338

337339
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {

kvcache/causal_test.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
280280

281281
out, _, mask := cache.Get(context)
282282

283-
context.Forward(out)
284-
context.Forward(mask)
285-
context.Compute(out, mask)
283+
context.Forward(out, mask).Compute(out, mask)
286284

287285
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
288286
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
@@ -344,7 +342,7 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
344342
return out, nil
345343
}
346344

347-
func (c *testContext) Forward(ml.Tensor) {}
345+
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
348346

349347
func (c *testContext) Compute(...ml.Tensor) {}
350348

kvcache/encoder.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
8080
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
8181
}
8282

83-
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
84-
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
83+
ctx.Forward(
84+
key.Copy(ctx, c.keys[c.curLayer]),
85+
value.Copy(ctx, c.values[c.curLayer]),
86+
)
8587
}
8688

8789
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {

0 commit comments

Comments
 (0)