Skip to content

Commit 9eef160

Browse files
committed
Removed warmup period
1 parent fa551ac commit 9eef160

File tree

3 files changed

+188
-15
lines changed

3 files changed

+188
-15
lines changed
+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
{
2+
"sigmoid" : { "Vmin" : 0.0, "Vmax" : 20.0, "Vh" : 16.0, "Vc" : 3.0 },
3+
"clamp" : { "Vmin" : 0.0, "Vmax" : 1000.0 },
4+
"weight" : { "min" : 0.25, "max" : 0.75, "noise" : 0.005 },
5+
6+
"threshold" : 40.0,
7+
8+
"time" : {
9+
"dt" : 0.001,
10+
"settling" : 1.5,
11+
"duration" : 2.5
12+
},
13+
14+
"input" : {
15+
"potential" : 7.0,
16+
"noise" : 0.01
17+
},
18+
19+
"RL" : {
20+
"init" : 0.5,
21+
"alpha" : 0.025,
22+
"LTP" : 0.005,
23+
"LTD" : 0.003
24+
},
25+
26+
"Hebbian" : {
27+
"LTP" : 0.0005
28+
},
29+
30+
"CTX" : { "tau" : 0.01, "rest" : -3.0, "noise" : 0.010 },
31+
"STR" : { "tau" : 0.01, "rest" : 0.0, "noise" : 0.001 },
32+
"STN" : { "tau" : 0.01, "rest" : -10.0, "noise" : 0.001 },
33+
"GPi" : { "tau" : 0.01, "rest" : 10.0, "noise" : 0.001 },
34+
"THL" : { "tau" : 0.01, "rest" : -40.0, "noise" : 0.001 },
35+
36+
"gain" : {
37+
"CTX:cog → STR:cog" : 1.0,
38+
"CTX:cog → STR:ass" : 0.2,
39+
"CTX:cog → STN:cog" : 1.0,
40+
"CTX:cog → THL:cog" : 0.4,
41+
42+
"CTX:mot → STR:mot" : 1.0,
43+
"CTX:mot → STR:ass" : 0.2,
44+
"CTX:mot → STN:mot" : 1.0,
45+
"CTX:mot → THL:mot" : 0.4,
46+
47+
"CTX:ass → STR:ass" : 1.0,
48+
49+
"STR:cog → GPi:cog" : -2.0,
50+
"STR:mot → GPi:mot" : -2.0,
51+
"STR:ass → GPi:cog" : -2.0,
52+
"STR:ass → GPi:mot" : -2.0,
53+
54+
"STN:cog → GPi:cog" : 1.0,
55+
"STN:mot → GPi:mot" : 1.0,
56+
57+
"GPi:cog → THL:cog" : -0.5,
58+
"GPi:mot → THL:mot" : -0.5,
59+
60+
"THL:cog → CTX:cog" : 1.0,
61+
"THL:mot → CTX:mot" : 1.0,
62+
63+
"CTX:mot → CTX:mot" : 0.0,
64+
"CTX:cog → CTX:cog" : 0.0,
65+
"CTX:ass → CTX:ass" : 0.0,
66+
67+
"CTX:cog → CTX:ass" : 0.0,
68+
"CTX:ass → CTX:mot" : 0.0,
69+
"CTX:mot → CTX:ass" : 0.0,
70+
"CTX:ass → CTX:cog" : 0.0
71+
}
72+
}
+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
{
2+
"sigmoid" : { "Vmin" : 0.0, "Vmax" : 20.0, "Vh" : 16.0, "Vc" : 3.0 },
3+
"clamp" : { "Vmin" : 0.0, "Vmax" : 1000.0 },
4+
"weight" : { "min" : 0.25, "max" : 0.75, "noise" : 0.005 },
5+
6+
"threshold" : 40.0,
7+
8+
"time" : {
9+
"dt" : 0.001,
10+
"settling" : 1.5,
11+
"duration" : 2.5
12+
},
13+
14+
"input" : {
15+
"potential" : 7.0,
16+
"noise" : 0.01
17+
},
18+
19+
"RL" : {
20+
"init" : 0.5,
21+
"alpha" : 0.025,
22+
"LTP" : 0.005,
23+
"LTD" : 0.003
24+
},
25+
26+
"Hebbian" : {
27+
"LTP" : 0.0005
28+
},
29+
30+
"CTX" : { "tau" : 0.01, "rest" : -3.0, "noise" : 0.010 },
31+
"STR" : { "tau" : 0.01, "rest" : 0.0, "noise" : 0.001 },
32+
"STN" : { "tau" : 0.01, "rest" : -10.0, "noise" : 0.001 },
33+
"GPi" : { "tau" : 0.01, "rest" : 10.0, "noise" : 0.001 },
34+
"THL" : { "tau" : 0.01, "rest" : -40.0, "noise" : 0.001 },
35+
36+
"gain" : {
37+
"CTX:cog → STR:cog" : 1.0,
38+
"CTX:cog → STR:ass" : 0.2,
39+
"CTX:cog → STN:cog" : 1.0,
40+
"CTX:cog → THL:cog" : 0.1,
41+
42+
"CTX:mot → STR:mot" : 1.0,
43+
"CTX:mot → STR:ass" : 0.2,
44+
"CTX:mot → STN:mot" : 1.0,
45+
"CTX:mot → THL:mot" : 0.1,
46+
47+
"CTX:ass → STR:ass" : 1.0,
48+
49+
"STR:cog → GPi:cog" : -2.0,
50+
"STR:mot → GPi:mot" : -2.0,
51+
"STR:ass → GPi:cog" : -2.0,
52+
"STR:ass → GPi:mot" : -2.0,
53+
54+
"STN:cog → GPi:cog" : 1.0,
55+
"STN:mot → GPi:mot" : 1.0,
56+
57+
"GPi:cog → THL:cog" : -0.3,
58+
"GPi:mot → THL:mot" : -0.3,
59+
60+
"THL:cog → CTX:cog" : 0.4,
61+
"THL:mot → CTX:mot" : 0.4,
62+
63+
"CTX:mot → CTX:mot" : 0.5,
64+
"CTX:cog → CTX:cog" : 0.5,
65+
"CTX:ass → CTX:ass" : 0.5,
66+
67+
"CTX:cog → CTX:ass" : 0.025,
68+
"CTX:ass → CTX:mot" : 0.025,
69+
"CTX:mot → CTX:ass" : 0.01,
70+
"CTX:ass → CTX:cog" : 0.01
71+
}
72+
}

experiments/single-trial-motor.py

+44-15
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,21 @@
1212
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
1313

1414
def plot_freq(ax, ctx, thl, gpi, stn, str, i0, i1, duration, dt, title, xlabel, ylabel):
15+
16+
start, stop = duration
17+
duration = stop-start
1518
timesteps = np.linspace(0, duration, duration/dt)
16-
stn = stn[:duration:dt]
17-
str = str[:duration:dt]
18-
ctx = ctx[:duration:dt]
19-
gpi = gpi[:duration:dt]
20-
thl = thl[:duration:dt]
19+
#stn = stn[:duration:dt]
20+
#str = str[:duration:dt]
21+
#ctx = ctx[:duration:dt]
22+
#gpi = gpi[:duration:dt]
23+
#thl = thl[:duration:dt]
24+
25+
stn = stn[start:stop:dt]
26+
str = str[start:stop:dt]
27+
ctx = ctx[start:stop:dt]
28+
gpi = gpi[start:stop:dt]
29+
thl = thl[start:stop:dt]
2130

2231
fontsize = 8
2332

@@ -28,7 +37,7 @@ def plot_freq(ax, ctx, thl, gpi, stn, str, i0, i1, duration, dt, title, xlabel,
2837

2938
x = np.argwhere((ctx[:,i0] - ctx[:,i1]) > 40)[0] * dt
3039
ax.text(x, -6, "↑\nDecision", fontsize=8, va="top", ha="center")
31-
ax.text(500, -6, "↑\nTrial start", fontsize=8, va="top", ha="center")
40+
ax.text(0, -6, "↑\nTrial start", fontsize=8, va="top", ha="center")
3241

3342
ax.plot(timesteps, ctx[:,i1], c=colors[0], ls="--")
3443
ax.plot(timesteps, ctx[:,i0], c=colors[0])
@@ -64,7 +73,8 @@ def plot_freq(ax, ctx, thl, gpi, stn, str, i0, i1, duration, dt, title, xlabel,
6473
ax.set_xlabel("Time (ms)")
6574
if ylabel:
6675
ax.set_ylabel("Firing rate (spikes/s)")
67-
ax.set_xticks([0,2000])
76+
ax.set_xticks([0,duration])
77+
ax.set_xticklabels(["","%d" % duration])
6878
ax.set_ylim(-5,145)
6979
ax.set_yticks([0,40,80,120])
7080

@@ -81,12 +91,27 @@ def plot_scatter(X, y, color):
8191
Y = y*np.ones(len(X))
8292
ax.scatter(X, Y, s=.5, marker='|', facecolor=color, edgecolor="none")
8393

94+
start, stop = duration
95+
duration = stop-start
8496
timesteps = np.linspace(0, duration, duration/dt)
85-
stn = stn[:duration:dt]
86-
str = str[:duration:dt]
87-
ctx = ctx[:duration:dt]
88-
gpi = gpi[:duration:dt]
89-
thl = thl[:duration:dt]
97+
#stn = stn[:duration:dt]
98+
#str = str[:duration:dt]
99+
#ctx = ctx[:duration:dt]
100+
#gpi = gpi[:duration:dt]
101+
#thl = thl[:duration:dt]
102+
103+
stn = stn[start:stop:dt]
104+
str = str[start:stop:dt]
105+
ctx = ctx[start:stop:dt]
106+
gpi = gpi[start:stop:dt]
107+
thl = thl[start:stop:dt]
108+
109+
# timesteps = np.linspace(0, duration, duration/dt)
110+
# stn = stn[:duration:dt]
111+
# str = str[:duration:dt]
112+
# ctx = ctx[:duration:dt]
113+
# gpi = gpi[:duration:dt]
114+
# thl = thl[:duration:dt]
90115

91116
n = 10
92117
y = 100
@@ -133,7 +158,7 @@ def plot_scatter(X, y, color):
133158

134159

135160
def setup(task_filename="task.json", model_filename="model.json"):
136-
seed = 123
161+
seed = 12345
137162
np.random.seed(seed)
138163
random.seed(seed)
139164
model = Model(model_filename)
@@ -164,12 +189,15 @@ def simulate(task, model, loop, gpi=0):
164189
return ctx, thl, gpi, stn, str, i0, i1
165190

166191

167-
168192
dt = 10
169-
duration = 2000
193+
194+
duration = (500, 2000)
195+
# duration = (0, 4000)
196+
# duration = 4000
170197

171198
fig = plt.figure(figsize=(6,10))
172199

200+
#task, model = setup("task.json", "model-guthrie-no-warmup.json")
173201
task, model = setup("task.json", "model-guthrie.json")
174202

175203
ax = plt.subplot(3,1,1)
@@ -178,6 +206,7 @@ def simulate(task, model, loop, gpi=0):
178206
"A Motor channel (no cortical competition)", 0, 1)
179207
plot_raster(ax, ctx, thl, gpi, stn, str, i0, i1, duration, dt, "")
180208

209+
# task, model = setup("task.json", "model-noisy-no-warmup.json")
181210
task, model = setup("task.json", "model-noisy.json")
182211

183212
ax = plt.subplot(3,1,3)

0 commit comments

Comments
 (0)