Skip to content

Fixed ChatSession.LoadSession #976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions LLama.Unittest/LLamaContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,35 @@ public void TokenizeEmpty()

Assert.Equal(Array.Empty<LLamaToken>(), tokens);
}

[Fact]
public void SaveLoadState()
{
using var state1 = _context.GetState();

var stream = new MemoryStream();
state1.Save(stream);

stream.Position = 0;

using var state2 = LLamaContext.State.Load(stream);

Assert.Equal(state1.Size, state2.Size);
}

[Fact]
public async Task SaveLoadStateAsync()
{
using var state1 = _context.GetState();

var stream = new MemoryStream();
await state1.SaveAsync(stream);

stream.Position = 0;

using var state2 = await LLamaContext.State.LoadAsync(stream);

Assert.Equal(state1.Size, state2.Size);
}
}
}
20 changes: 12 additions & 8 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -466,16 +466,15 @@ public void Dispose()
public class State
: SafeLLamaHandleBase
{
private readonly nuint _size;
/// <summary>
/// Get the size in bytes of this state object
/// </summary>
public nuint Size => _size;
public nuint Size { get; }

internal State(IntPtr memory, nuint size)
: base(memory, true)
{
_size = size;
Size = size;
}

/// <inheritdoc />
Expand All @@ -494,7 +493,8 @@ public async Task SaveAsync(Stream stream)
UnmanagedMemoryStream from;
unsafe
{
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
var length = (long)Size;
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), length, length, FileAccess.Read);
}
await from.CopyToAsync(stream);
}
Expand All @@ -508,7 +508,8 @@ public void Save(Stream stream)
UnmanagedMemoryStream from;
unsafe
{
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
var length = (long)Size;
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), length, length, FileAccess.Read);
}
from.CopyTo(stream);
}
Expand All @@ -526,7 +527,8 @@ public static async Task<State> LoadAsync(Stream stream)
UnmanagedMemoryStream dest;
unsafe
{
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
var length = stream.Length;
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), length, length, FileAccess.Write);
}
await stream.CopyToAsync(dest);

Expand All @@ -543,11 +545,13 @@ public static State Load(Stream stream)
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, (nuint)stream.Length);

UnmanagedMemoryStream dest;
unsafe
{
var dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
stream.CopyTo(dest);
var length = stream.Length;
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), length, length, FileAccess.Write);
}
stream.CopyTo(dest);

return state;
}
Expand Down
Loading