import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
from scipy.signal import butter, filtfilt, iirnotch
from sklearn.cross_decomposition import CCA
# Directory holding the per-session .mat files (a relative path, so it is
# resolved against the current working directory).
DATA_DIR = Path("data")
# Candidate stimulation frequencies in Hz; fbcca_predict returns one of these
# and sweep compares its predictions directly against the trial labels.
STIM_FREQS = [9, 10, 12, 15]
# Number of harmonics in each sinusoidal reference template (one sin and one
# cos column per harmonic).
N_HARMONICS = 2
# (low, high) pass-band edges in Hz of the filter-bank sub-bands used by
# fbcca_predict; with its default exponent, earlier bands get larger weights.
SUB_BANDS = [(6, 14), (14, 22), (22, 30), (30, 40)]
# Analysis-window lengths in seconds: 1.0, 1.5, ..., 7.0 (the stop value 7.5
# is excluded by np.arange).
WINDOW_LENGTHS = np.arange(1.0, 7.5, 0.5)
def make_filters(fs, low=5.0, high=40.0, line=50.0, order=4, notch_q=30):
    """Design the preprocessing filters for a recording sampled at ``fs`` Hz.

    Parameters
    ----------
    fs : float
        Sampling frequency in Hz.
    low, high : float
        Band-pass edges in Hz.
    line : float
        Centre frequency of the notch in Hz (mains interference).
    order : int
        Butterworth design order of the band-pass.
    notch_q : float
        Quality factor of the notch.

    Returns
    -------
    tuple
        ``((b, a), (b, a))``: band-pass coefficients first, then notch
        coefficients, each ready to be unpacked into ``filtfilt``.
    """
    bandpass = butter(order, [low, high], btype="band", fs=fs)
    notch = iirnotch(line, notch_q, fs=fs)
    return tuple(bandpass), tuple(notch)
def find_trials(y):
    """Locate trials as maximal runs of non-zero samples on trigger row ``y[9]``.

    Parameters
    ----------
    y : array-like
        Recording matrix whose row 9 is the trigger channel.

    Returns
    -------
    list of (int, int, int)
        ``(start, end, code)`` per trial, with ``end`` exclusive and ``code``
        the trigger value at ``start`` (presumably the stimulation frequency,
        since callers compare it with ``STIM_FREQS``; verify against the data).

    Only runs whose both edges fall inside the recording are returned. A run
    already active at sample 0, or still active at the last sample, is
    discarded rather than mis-paired with a neighbouring trial.
    """
    trigger = np.asarray(y[9]).astype(int)
    active = (trigger != 0).astype(int)
    edges = np.diff(active)
    starts = np.flatnonzero(edges == 1) + 1
    ends = np.flatnonzero(edges == -1) + 1
    # A run in progress at sample 0 produces an end with no matching start;
    # drop it so starts[i] and ends[i] describe the same trial.
    if active.size and active[0]:
        ends = ends[1:]
    # zip() discards a trailing start whose run never ends in the recording.
    return [(int(s), int(e), int(trigger[s])) for s, e in zip(starts, ends)]
def load_session(path):
    """Load one .mat recording and cut it into equal-length filtered epochs.

    The file must provide ``fs`` (sampling rate) and ``y`` (rows = signals,
    columns = samples). Rows 1-8 (presumably the EEG channels) are band-passed
    and then notch-filtered with zero-phase ``filtfilt`` over the whole
    recording. Trials are found on the raw trigger row via ``find_trials``.

    Every epoch has the length of the first trial. A later trial whose window
    would run past the end of the recording is dropped with a warning.

    Returns
    -------
    epochs : ndarray, shape (n_trials, 8, n_samples)
    labels : ndarray, shape (n_trials,)
        Trigger code of each trial.
    fs : int

    Raises
    ------
    ValueError
        If no complete trial is found in the recording.
    """
    mat = scipy.io.loadmat(path)
    fs = int(mat["fs"][0, 0])
    y = mat["y"]
    bp, notch = make_filters(fs)
    y_filt = y.astype(float)  # astype already returns a new array
    for ci in range(1, 9):
        y_filt[ci] = filtfilt(*notch, filtfilt(*bp, y[ci]))
    trials = find_trials(y)
    if not trials:
        raise ValueError(f"no complete trials found on the trigger row of {path}")
    n_samples = trials[0][1] - trials[0][0]
    n_total = y_filt.shape[1]
    # A fixed-length window starting near the end of the recording would be
    # shorter than n_samples and break np.stack; keep only windows that fit.
    complete = [trial for trial in trials if trial[0] + n_samples <= n_total]
    if len(complete) < len(trials):
        warnings.warn(
            f"{path}: dropped {len(trials) - len(complete)} trial(s) whose "
            f"{n_samples}-sample epoch runs past the end of the recording",
            stacklevel=2,
        )
    epochs = np.stack([y_filt[1:9, s:s + n_samples] for s, _, _ in complete])
    labels = np.array([fr for _, _, fr in complete])
    return epochs, labels, fs
def build_template(freq, n_samples, fs, n_harmonics=N_HARMONICS):
    """Build the sinusoidal reference signals for one stimulation frequency.

    Parameters
    ----------
    freq : float
        Fundamental frequency in Hz.
    n_samples : int
        Number of time points (rows) in the template.
    fs : float
        Sampling frequency in Hz.
    n_harmonics : int
        Number of harmonics; each contributes one sin and one cos column.

    Returns
    -------
    ndarray, shape (n_samples, 2 * n_harmonics)
        Columns ordered ``sin(f), cos(f), sin(2f), cos(2f), ...``.
    """
    time = np.arange(n_samples) / fs
    columns = []
    for harmonic in range(1, n_harmonics + 1):
        phase = 2 * np.pi * harmonic * freq * time
        columns.extend((np.sin(phase), np.cos(phase)))
    return np.column_stack(columns)
def _centered_basis(M):
    """Return an orthonormal basis of the column space of the column-centred ``M``."""
    M = np.asarray(M, dtype=float)
    if M.ndim == 1:
        M = M[:, None]
    M = M - M.mean(axis=0)
    U, s, _ = np.linalg.svd(M, full_matrices=False)
    if s.size == 0 or s[0] == 0.0:
        return U[:, :0]
    # Drop numerically-null directions so rank-deficient inputs (e.g. a flat
    # channel) cannot inflate the correlation.
    tol = s[0] * max(M.shape) * np.finfo(float).eps
    return U[:, s > tol]


def cca_score(X, Y):
    """Return the first canonical correlation between ``X`` and ``Y``.

    Rows are observations (time samples) and columns are variables. The value
    is computed exactly as the largest singular value of ``Qx.T @ Qy``, where
    ``Qx`` and ``Qy`` are orthonormal bases of the centred inputs. This avoids
    the iterative (and possibly non-converged) sklearn fit and is much faster.

    Returns
    -------
    float
        Correlation in ``[0, 1]``. It is 0.0 when either input has no
        variance, instead of NaN.
    """
    Qx = _centered_basis(X)
    Qy = _centered_basis(Y)
    if Qx.shape[1] == 0 or Qy.shape[1] == 0:
        return 0.0
    rho = np.linalg.svd(Qx.T @ Qy, compute_uv=False)[0]
    return float(min(rho, 1.0))  # clamp rounding just above 1
def fbcca_predict(epoch, fs, sub_bands=SUB_BANDS, a=1.25, b=0.25):
    """Classify one epoch with filter-bank CCA.

    Parameters
    ----------
    epoch : ndarray, shape (n_channels, n_samples)
        Time runs along the last axis.
    fs : float
        Sampling frequency in Hz.
    sub_bands : sequence of (low, high)
        Band edges in Hz of each filter-bank sub-band.
    a, b : float
        Sub-band weighting ``w_k = k ** -a + b`` for ``k = 1, 2, ...``.

    Returns
    -------
    The element of ``STIM_FREQS`` whose weighted sum of squared canonical
    correlations across sub-bands is largest.
    """
    n_samples = epoch.shape[1]
    band_weights = np.array([k ** -a + b for k in range(1, len(sub_bands) + 1)])

    # Zero-phase 4th-order Butterworth band-pass per sub-band, along time.
    filtered = [
        filtfilt(*butter(4, [low, high], btype="band", fs=fs), epoch, axis=-1)
        for low, high in sub_bands
    ]

    def _weighted_score(freq):
        reference = build_template(freq, n_samples, fs, N_HARMONICS)
        return sum(
            weight * cca_score(band.T, reference) ** 2
            for weight, band in zip(band_weights, filtered)
        )

    totals = np.array([_weighted_score(freq) for freq in STIM_FREQS])
    best = int(np.argmax(totals))
    return STIM_FREQS[best]
def wolpaw_itr(p, N, T):
    """Wolpaw information transfer rate in bits per minute.

    Parameters
    ----------
    p : float or array-like
        Classification accuracy in ``[0, 1]``.
    N : int
        Number of selectable classes.
    T : float or array-like
        Seconds per selection; broadcast against ``p``.

    Accuracies at or below chance (``1/N``) give 0 bits; accuracies of 1 or
    more give the maximum ``log2(N)`` bits.
    """
    acc = np.asarray(p, dtype=float)
    log_n = np.log2(N)
    out = np.zeros_like(acc)
    perfect = acc >= 1.0
    partial = ~perfect & (acc > 1.0 / N)
    out[perfect] = log_n
    q = acc[partial]
    out[partial] = log_n + q * np.log2(q) + (1 - q) * np.log2((1 - q) / (N - 1))
    return out * 60.0 / T
def sweep(epochs, labels, fs, lengths):
    """Evaluate FBCCA accuracy for each analysis-window length.

    Parameters
    ----------
    epochs : ndarray, shape (n_trials, n_channels, n_samples)
    labels : array-like, shape (n_trials,)
        True class of each trial, compared with ``fbcca_predict`` output.
    fs : float
        Sampling frequency in Hz.
    lengths : iterable of float
        Window lengths in seconds; each window is the first ``int(L * fs)``
        samples of every epoch.

    Returns
    -------
    ndarray
        Accuracy per window length, in the order of ``lengths``.

    Warns
    -----
    UserWarning
        When a window is longer than the epochs. Slicing then silently falls
        back to the full epoch, so that entry does not reflect ``L`` seconds
        of data.
    """
    n_available = epochs.shape[-1]
    accs = []
    for L in lengths:
        n_L = int(L * fs)
        if n_L > n_available:
            warnings.warn(
                f"window of {L} s ({n_L} samples) exceeds the epoch length "
                f"({n_available} samples); the full epoch is used instead",
                stacklevel=2,
            )
        truncated = epochs[:, :, :n_L]
        preds = np.array([fbcca_predict(ep, fs) for ep in truncated])
        accs.append((preds == labels).mean())
    return np.array(accs)
# Recordings to evaluate: (display name, file name under DATA_DIR, colour,
# line style). The last two presumably are matplotlib style codes used when
# plotting ``results``.
SESSIONS = [
    ("S1 run 1", "subject_1_fvep_led_training_1.mat", "C0", "-"),
    ("S1 run 2", "subject_1_fvep_led_training_2.mat", "C0", "--"),
    ("S2 run 1", "subject_2_fvep_led_training_1.mat", "C1", "-"),
    ("S2 run 2", "subject_2_fvep_led_training_2.mat", "C1", "--"),
]

# One dict per session holding the accuracy and ITR curves over WINDOW_LENGTHS.
results = []
for name, fname, color, ls in SESSIONS:
    epochs, labels, fs = load_session(DATA_DIR / fname)
    accs = sweep(epochs, labels, fs, WINDOW_LENGTHS)
    # The ITR class count is the number of candidate frequencies fbcca_predict
    # chooses from; derive it instead of hard-coding 4. T is the window length
    # in seconds only (no inter-trial/gaze-shift time).
    itrs = wolpaw_itr(accs, N=len(STIM_FREQS), T=WINDOW_LENGTHS)
    results.append({"name": name, "accs": accs, "itrs": itrs, "color": color, "ls": ls})
    print(f" {name:<10} done — peak acc {accs.max():.0%}, peak ITR {itrs.max():.1f} bits/min")